{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-pVhOfzLx9us"
   },
   "source": [
    "#Waymo Open Dataset 2D Panoramic Video Panoptic Segmentation Tutorial\n",
    "\n",
    "- Website: https://waymo.com/open\n",
    "- GitHub: https://github.com/waymo-research/waymo-open-dataset\n",
    "\n",
    "This tutorial demonstrates how to decode and interpret the 2D panoramic video panoptic segmentation labels. Visit the [Waymo Open Dataset Website](https://waymo.com/open) to download the full dataset.\n",
    "\n",
    "## Dataset\n",
    "This dataset contains panoptic segmentation labels for a subset of the Open\n",
    "Dataset camera images. In addition, we provide associations for instances between different camera images and over time, allowing for panoramic video panoptic segmentation.\n",
    "\n",
    "For the training set, we provide tracked sequences of 5 temporal frames, spaced at t=[0ms, 400ms, 600ms, 800ms, 1200ms]. For each labeled time step, we  label all 5 cameras around the Waymo vehicle, resulting in a total of 25 labeled images per sequence. This allows for tracking over a variety of different time frames and viewpoints. \n",
    "\n",
    "In particular, this non-uniform spacing maximizes the temporal span of each set of frames (maximize the dt within each set) for larger temporal displacements, while providing some temporally neighboring frames to make learning easier. These two sets of dts also increases the diversity between pairs of frames while respecting the previous two goals.\n",
    "\n",
    "For the validation set, we label entire run segments at 5Hz (every other image), resulting in sequences of 100 temporal frames over 5 cameras (500 labels per sequence).\n",
    "\n",
    "## Instructions\n",
    "This colab will demonstrate how to read the labels, and to extract panoptic labels with consistent instance ID tracks for any number of frames.\n",
    "\n",
    "To run, use this [colab link](https://colab.research.google.com/github/waymo-research/waymo-open-dataset/blob/master/tutorial/tutorial_maps.ipynb) to open directly in colab."
   ]
  },
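  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The label counts and time gaps described above can be checked with a short sketch. This is only illustrative arithmetic based on the numbers quoted in this section; the variable names are hypothetical and are not part of the Waymo Open Dataset API.\n",
    "\n",
    "```python\n",
    "from itertools import combinations\n",
    "\n",
    "# Offsets (in ms) of the labeled training frames, as listed above.\n",
    "train_offsets_ms = [0, 400, 600, 800, 1200]\n",
    "num_cameras = 5\n",
    "\n",
    "# One label per camera per labeled time step.\n",
    "labeled_images_per_train_sequence = len(train_offsets_ms) * num_cameras\n",
    "\n",
    "# Distinct time gaps available between any pair of labeled frames.\n",
    "pairwise_dts_ms = sorted({b - a for a, b in combinations(train_offsets_ms, 2)})\n",
    "\n",
    "# Validation sequences: 100 labeled time steps over the same 5 cameras.\n",
    "labels_per_val_sequence = 100 * num_cameras\n",
    "\n",
    "print(labeled_images_per_train_sequence, pairwise_dts_ms, labels_per_val_sequence)\n",
    "```\n",
    "\n",
    "Under these assumptions, the five training offsets yield pairwise gaps from 200ms up to the full 1200ms span."
   ]
  },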
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "f4PlTs1o3f5A"
   },
   "source": [
    "## Install Waymo Open Dataset Package\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1sPLur9kMaLh"
   },
   "source": [
    "# Package Installation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kTP04tYwNOIq"
   },
   "outputs": [],
   "source": [
    "!pip3 install waymo-open-dataset-tf-2-12-0==1.6.7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "s31F-VVWLa7z"
   },
   "outputs": [],
   "source": [
    "#@markdown ## (Optional) Install a minimal version of deeplab2 for the WOD.\n",
    "# Please ignore this cell if you already have deeplab2 installed\n",
    "\n",
    "# This shell script will download and install only those deeplab2 modules which\n",
    "# are used by the WOD.\n",
    "# They are used here https://github.com/waymo-research/waymo-open-dataset/blob/master/src/waymo_open_dataset/bazel/deeplab2.BUILD\n",
    "!wget https://raw.githubusercontent.com/waymo-research/waymo-open-dataset/master/src/waymo_open_dataset/pip_pkg_scripts/install_deeplab2.sh -O - | bash\n",
    "\n",
    "# Refer to the instructions on how intall the entire deeplab2 if you need other\n",
    "# deeplab2 modules as well.\n",
    "# https://github.com/google-research/deeplab2/blob/main/g3doc/setup/installation.md"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rqs8_62VNc4T"
   },
   "source": [
    "# Imports and global definitions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YuNAlbQpNkLa"
   },
   "outputs": [],
   "source": [
    "#@markdown ## Data location. Please edit.\n",
    "\n",
    "#@markdown #### Replace this path with your own tfrecords. A tfrecord containing tf.Example protos as downloaded from the Waymo dataset webpage.\n",
    "FILE_NAME = '/content/waymo-open-dataset/tutorial/.../tfexample.tfrecord' #@param {type:\"string\"}\n",
    "\n",
    "#@markdown #### Replace this directory with your own dataset folder for evaluation.\n",
    "EVAL_DIR = '/dataset_path/validation/' #@param {type: \"string\"}\n",
    "EVAL_RUNS = ['segment-1024360143612057520_3580_000_3600_000_with_camera_labels.tfrecord','segment-11048712972908676520_545_000_565_000_with_camera_labels.tfrecord'] #@param {type: \"raw\"}\n",
    "\n",
    "#@markdown #### Replace this path with the real path. Each line of the file is the \"\\<context_name>, \\<timestamp_micros>\" of a frame with camera segmentation labels.\n",
    "TEST_SET_SOURCE = '/content/waymo-open-dataset/tutorial/2d_pvps_validation_frames.txt' #@param {type: \"string\"} \n",
    "\n",
    "#@markdown #### Replace this directory with your own testing dataset folder.\n",
    "TEST_DIR = '/dataset_path/testing/' #@param {type: \"string\"}\n",
    "\n",
    "#@markdown #### Replace this directory with your own local folder saving the submission.\n",
    "SAVE_FOLDER = '/tmp/camera_segmentation_challenge/testing/' #@param {type: \"string\"}\n",
    "\n"]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xCDNLdp9Ni8a"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from typing import Any, Dict, Iterator, List, Optional, Sequence, Tuple\n",
    "import immutabledict\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import multiprocessing as mp\n",
    "import numpy as np\n",
    "import dask.dataframe as dd\n",
    "\n",
    "if not tf.executing_eagerly():\n",
    "  tf.compat.v1.enable_eager_execution()\n",
    "\n",
    "from waymo_open_dataset import dataset_pb2 as open_dataset\n",
    "from waymo_open_dataset import v2\n",
    "from waymo_open_dataset.protos import camera_segmentation_metrics_pb2 as metrics_pb2\n",
    "from waymo_open_dataset.protos import camera_segmentation_submission_pb2 as submission_pb2\n",
    "from waymo_open_dataset.wdl_limited.camera_segmentation import camera_segmentation_metrics\n",
    "from waymo_open_dataset.utils import camera_segmentation_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ibor0U9XBlX6"
   },
   "source": [
    "# Read 2D panoptic segmentation labels from Frame proto\n",
    "Note that only a subset of the frames have 2D panoptic labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O41R3lljM9Ym"
   },
   "outputs": [],
   "source": [
    "dataset = tf.data.TFRecordDataset(FILE_NAME, compression_type='')\n",
    "frames_with_seg = []\n",
    "sequence_id = None\n",
    "for data in dataset:\n",
    "  frame = open_dataset.Frame()\n",
    "  frame.ParseFromString(bytearray(data.numpy()))\n",
    "  # Save frames which contain CameraSegmentationLabel messages. We assume that\n",
    "  # if the first image has segmentation labels, all images in this frame will.\n",
    "  if frame.images[0].camera_segmentation_label.panoptic_label:\n",
    "    frames_with_seg.append(frame)\n",
    "    if sequence_id is None:\n",
    "      sequence_id = frame.images[0].camera_segmentation_label.sequence_id\n",
    "    # Collect 3 frames for this demo. However, any number can be used in practice.\n",
    "    if frame.images[0].camera_segmentation_label.sequence_id != sequence_id or len(frames_with_seg) > 2:\n",
    "      break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wHK95_JBUXUx"
   },
   "outputs": [],
   "source": [
    "# Organize the segmentation labels in order from left to right for viz later.\n",
    "camera_left_to_right_order = [open_dataset.CameraName.SIDE_LEFT,\n",
    "                              open_dataset.CameraName.FRONT_LEFT,\n",
    "                              open_dataset.CameraName.FRONT,\n",
    "                              open_dataset.CameraName.FRONT_RIGHT,\n",
    "                              open_dataset.CameraName.SIDE_RIGHT]\n",
    "segmentation_protos_ordered = []\n",
    "for frame in frames_with_seg:\n",
    "  segmentation_proto_dict = {image.name : image.camera_segmentation_label for image in frame.images}\n",
    "  segmentation_protos_ordered.append([segmentation_proto_dict[name] for name in camera_left_to_right_order])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Osrj5Gr6TOYr"
   },
   "source": [
    "# Read 2D panoptic segmentation labels from v2 components\n",
    "\n",
    "This section aims to replicate the functionality provided above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gS6MKTorSQrT"
   },
   "outputs": [],
   "source": [
    "context_name = '550171902340535682_2640_000_2660_000'\n",
    "def read(tag: str, dataset_dir: str = EVAL_DIR) -> dd.DataFrame:\n",
    "  \"\"\"Creates a Dask DataFrame for the component specified by its tag.\"\"\"\n",
    "  paths = f'{dataset_dir}/{tag}/{context_name}.parquet'\n",
    "  return dd.read_parquet(paths)\n",
    "\n",
    "cam_segmentation_df = read('camera_segmentation')\n",
    "\n",
    "# Group segmentation labels into frames by context name and timestamp.\n",
    "frame_keys = ['key.segment_context_name', 'key.frame_timestamp_micros']\n",
    "cam_segmentation_per_frame_df = cam_segmentation_df.groupby(\n",
    "    frame_keys, group_keys=False).agg(list)\n",
    "\n",
    "def ungroup_row(key_names: Sequence[str],\n",
    "                key_values: Sequence[str],\n",
    "                row: dd.DataFrame) -> Iterator[Dict[str, Any]]:\n",
    "  \"\"\"Splits a group of dataframes into individual dicts.\"\"\"\n",
    "  keys = dict(zip(key_names, key_values))\n",
    "  cols, cells = list(zip(*[(col, cell) for col, cell in r.items()]))\n",
    "  for values in zip(*cells):\n",
    "    yield dict(zip(cols, values), **keys)\n",
    "\n",
    "cam_segmentation_list = []\n",
    "for i, (key_values, r) in enumerate(cam_segmentation_per_frame_df.iterrows()):\n",
    "  # Read three sequences of 5 camera images for this demo.\n",
    "  if i >= 3:\n",
    "    break\n",
    "  # Store a segmentation label component for each camera.\n",
    "  cam_segmentation_list.append(\n",
    "      [v2.CameraSegmentationLabelComponent.from_dict(d) \n",
    "       for d in ungroup_row(frame_keys, key_values, r)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VMsODRP4YrHj"
   },
   "outputs": [],
   "source": [
    "# Order labels from left to right for visualization later.\n",
    "# For each frame with segmentation labels, all cameras should have a label.\n",
    "camera_left_to_right_order = [open_dataset.CameraName.SIDE_LEFT,\n",
    "                              open_dataset.CameraName.FRONT_LEFT,\n",
    "                              open_dataset.CameraName.FRONT,\n",
    "                              open_dataset.CameraName.FRONT_RIGHT,\n",
    "                              open_dataset.CameraName.SIDE_RIGHT]\n",
    "segmentation_protos_ordered = []\n",
    "for it, label_list in enumerate(cam_segmentation_list):\n",
    "  segmentation_dict = {label.key.camera_name: label for label in label_list}\n",
    "  segmentation_protos_ordered.append([segmentation_dict[name] for name in camera_left_to_right_order])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wcDCYUF8Y1pY"
   },
   "source": [
    "# Read a single panoptic label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zYtrSEkXgpH8"
   },
   "outputs": [],
   "source": [
    "# Decode a single panoptic label.\n",
    "panoptic_label_front = camera_segmentation_utils.decode_single_panoptic_label_from_proto(\n",
    "    segmentation_protos_ordered[0][open_dataset.CameraName.FRONT]\n",
    ")\n",
    "\n",
    "# Separate the panoptic label into semantic and instance labels.\n",
    "semantic_label_front, instance_label_front = camera_segmentation_utils.decode_semantic_and_instance_labels_from_panoptic_label(\n",
    "    panoptic_label_front,\n",
    "    segmentation_protos_ordered[0][open_dataset.CameraName.FRONT].panoptic_label_divisor\n",
    ")"
   ]
  },
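  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The split into semantic and instance labels can be illustrated with plain NumPy. The sketch below assumes the common panoptic encoding `panoptic = semantic * panoptic_label_divisor + instance`; `split_panoptic_label` and the toy arrays are hypothetical and only mirror what the library call above is expected to do.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def split_panoptic_label(panoptic_label, panoptic_label_divisor):\n",
    "  # Hypothetical helper: the quotient is the semantic class and the remainder\n",
    "  # is the instance ID, under the encoding assumed above.\n",
    "  if panoptic_label_divisor <= 0:\n",
    "    raise ValueError('panoptic_label_divisor must be positive.')\n",
    "  return np.divmod(panoptic_label, panoptic_label_divisor)\n",
    "\n",
    "\n",
    "# Toy 2x2 label maps; the divisor of 1000 is an arbitrary example value.\n",
    "toy_divisor = 1000\n",
    "toy_semantic = np.array([[2, 2], [7, 0]], dtype=np.int32)\n",
    "toy_instance = np.array([[1, 2], [0, 0]], dtype=np.int32)\n",
    "toy_panoptic = toy_semantic * toy_divisor + toy_instance\n",
    "\n",
    "decoded_semantic, decoded_instance = split_panoptic_label(toy_panoptic, toy_divisor)\n",
    "```\n",
    "\n",
    "The round trip only works when every instance ID is smaller than the divisor, which is why the divisor has to exceed the largest number of instances in a frame."
   ]
  },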
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6Cm8lagZY3ip"
   },
   "source": [
    "# Read panoptic labels with consistent instance IDs over cameras and time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0IEvJgL9gMdR"
   },
   "outputs": [],
   "source": [
    "# The dataset provides tracking for instances between cameras and over time.\n",
    "# By setting remap_to_global=True, this function will remap the instance IDs in\n",
    "# each image so that instances for the same object will have the same ID between\n",
    "# different cameras and over time.\n",
    "segmentation_protos_flat = sum(segmentation_protos_ordered, [])\n",
    "panoptic_labels, num_cameras_covered, is_tracked_masks, panoptic_label_divisor = camera_segmentation_utils.decode_multi_frame_panoptic_labels_from_segmentation_labels(\n",
    "    segmentation_protos_flat, remap_to_global=True\n",
    ")\n",
    "\n",
    "# We can further separate the semantic and instance labels from the panoptic\n",
    "# labels.\n",
    "NUM_CAMERA_FRAMES = 5\n",
    "semantic_labels_multiframe = []\n",
    "instance_labels_multiframe = []\n",
    "for i in range(0, len(segmentation_protos_flat), NUM_CAMERA_FRAMES):\n",
    "  semantic_labels = []\n",
    "  instance_labels = []\n",
    "  for j in range(NUM_CAMERA_FRAMES):\n",
    "    semantic_label, instance_label = camera_segmentation_utils.decode_semantic_and_instance_labels_from_panoptic_label(\n",
    "      panoptic_labels[i + j], panoptic_label_divisor)\n",
    "    semantic_labels.append(semantic_label)\n",
    "    instance_labels.append(instance_label)\n",
    "  semantic_labels_multiframe.append(semantic_labels)\n",
    "  instance_labels_multiframe.append(instance_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "c6gpEl0if3nu"
   },
   "source": [
    "# Visualize the panoptic segmentation labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hj5lXAlJXjKk"
   },
   "outputs": [],
   "source": [
    "def _pad_to_common_shape(label):\n",
    "  return np.pad(label, [[1280 - label.shape[0], 0], [0, 0], [0, 0]])\n",
    "\n",
    "# Pad labels to a common size so that they can be concatenated.\n",
    "instance_labels = [[_pad_to_common_shape(label) for label in instance_labels] for instance_labels in instance_labels_multiframe]\n",
    "semantic_labels = [[_pad_to_common_shape(label) for label in semantic_labels] for semantic_labels in semantic_labels_multiframe]\n",
    "instance_labels = [np.concatenate(label, axis=1) for label in instance_labels]\n",
    "semantic_labels = [np.concatenate(label, axis=1) for label in semantic_labels]\n",
    "\n",
    "instance_label_concat = np.concatenate(instance_labels, axis=0)\n",
    "semantic_label_concat = np.concatenate(semantic_labels, axis=0)\n",
    "panoptic_label_rgb = camera_segmentation_utils.panoptic_label_to_rgb(\n",
    "    semantic_label_concat, instance_label_concat)"
   ]
  },
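  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The padding step above exists because label maps of different heights cannot be concatenated side by side. A minimal sketch of the same idea on toy arrays follows; `pad_top_to_height` and the array sizes are hypothetical stand-ins, not part of the library.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def pad_top_to_height(label, target_height):\n",
    "  # Adds zero-valued rows above the label so that its height matches\n",
    "  # target_height. Width and channel dimensions are left unchanged.\n",
    "  pad_rows = target_height - label.shape[0]\n",
    "  return np.pad(label, [[pad_rows, 0], [0, 0], [0, 0]])\n",
    "\n",
    "\n",
    "# Two toy single-channel label maps with different heights.\n",
    "tall_label = np.ones((4, 3, 1), dtype=np.int32)\n",
    "short_label = np.full((2, 3, 1), 5, dtype=np.int32)\n",
    "\n",
    "# After padding, both maps share a height and can be tiled along the width.\n",
    "tiled_row = np.concatenate(\n",
    "    [pad_top_to_height(label, 4) for label in (tall_label, short_label)],\n",
    "    axis=1)\n",
    "```\n",
    "\n",
    "Padding at the top keeps the bottom edges of the images aligned in the tiled visualization."
   ]
  },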
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fgCDPt9zeV_k"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(64, 60))\n",
    "plt.imshow(panoptic_label_rgb)\n",
    "plt.grid(False)\n",
    "plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9ALsPTBuOXJj"
   },
   "source": [
    "# Evaluate Panoptic Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f6B9BdAuZLyN"
   },
   "outputs": [],
   "source": [
    "#@title Metric computation utility functions\n",
    "\n",
    "def _run_dummy_inference_from_protos(\n",
    "    image_proto_list: Sequence[bytes]\n",
    ") -> List[np.ndarray]:\n",
    "  \"\"\"Creates dummy predictions from protos.\"\"\"\n",
    "  panoptic_preds = []\n",
    "  for image_proto in image_proto_list:\n",
    "    image_array = tf.image.decode_jpeg(image_proto).numpy()\n",
    "    # Creates a dummy prediction by setting the panoptic labels to 0 for all pixels.\n",
    "    panoptic_pred = np.zeros(\n",
    "        (image_array.shape[0], image_array.shape[1], 1), dtype=np.int32)\n",
    "    panoptic_preds.append(panoptic_pred)\n",
    "  return panoptic_preds\n",
    "\n",
    "\n",
    "def _compute_metric_for_dataset(filename: str):\n",
    "  \"\"\"Computes metric for the dataset frames.\"\"\"\n",
    "  eval_config = camera_segmentation_metrics.get_eval_config()\n",
    "  new_panoptic_label_divisor = eval_config.panoptic_label_divisor\n",
    "\n",
    "  dataset = tf.data.TFRecordDataset(filename, compression_type='')\n",
    "  # Load first 3 frames in the demo.\n",
    "  frames_with_seg, sequence_id = (\n",
    "      camera_segmentation_utils.load_frames_with_labels_from_dataset(dataset, 3))\n",
    "\n",
    "  segmentation_protos_ordered = []\n",
    "  image_protos_ordered = []\n",
    "  # Only aggregates frames with camera segmentation labels.\n",
    "  for frame in frames_with_seg:\n",
    "    segmentation_proto_dict = {image.name : image.camera_segmentation_label for image in frame.images}\n",
    "    segmentation_protos_ordered.append([segmentation_proto_dict[name] for name in camera_left_to_right_order])\n",
    "    image_proto_dict = {image.name: image.image for image in frame.images}\n",
    "    image_protos_ordered.append([image_proto_dict[name] for name in camera_left_to_right_order])\n",
    "\n",
    "  # The dataset provides tracking for instances between cameras and over time.\n",
    "  # By setting remap_to_global=True, this function will remap the instance IDs in\n",
    "  # each image so that instances for the same object will have the same ID between\n",
    "  # different cameras and over time.\n",
    "  segmentation_protos_flat = sum(segmentation_protos_ordered, [])\n",
    "  image_protos_flat = sum(image_protos_ordered, [])\n",
    "  decoded_elements = (\n",
    "      camera_segmentation_utils.decode_multi_frame_panoptic_labels_from_segmentation_labels(\n",
    "          segmentation_protos_flat, remap_to_global=True,\n",
    "          new_panoptic_label_divisor=new_panoptic_label_divisor)\n",
    "  )\n",
    "  panoptic_labels, num_cameras_covered, is_tracked_masks = decoded_elements[0:3]\n",
    "  \n",
    "  # We provide a dummy inference function in the demo. Please replace this with \n",
    "  # your own method. It is recommended to generate your own panoptic labels first\n",
    "  # and implement a function to load the generated panoptic labels from the disk.\n",
    "  panoptic_preds = _run_dummy_inference_from_protos(image_protos_flat)\n",
    "  return camera_segmentation_metrics.get_metric_object_by_sequence(\n",
    "    true_panoptic_labels=panoptic_labels,\n",
    "    pred_panoptic_labels=panoptic_preds,\n",
    "    num_cameras_covered=num_cameras_covered,\n",
    "    is_tracked_masks=is_tracked_masks,\n",
    "    sequence_id=sequence_id,\n",
    "  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ki4xdqN2QYG1"
   },
   "outputs": [],
   "source": [
    "#@title Metric computation for a single sequence\n",
    "\n",
    "metric_object = _compute_metric_for_dataset(FILE_NAME)\n",
    "single_sequence_result = camera_segmentation_metrics.aggregate_metrics(\n",
    "    [metric_object])\n",
    "print('Metrics:')\n",
    "print(single_sequence_result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eqxLTfmPZWJO"
   },
   "outputs": [],
   "source": [
    "#@title Metric computation for multiple sequences (sequential)\n",
    "\n",
    "eval_filenames = [os.path.join(EVAL_DIR, eval_run) for eval_run in EVAL_RUNS]\n",
    "\n",
    "# For Loop version.\n",
    "multi_sequence_result = camera_segmentation_metrics.aggregate_metrics(\n",
    "    [_compute_metric_for_dataset(filename) for filename in eval_filenames]\n",
    ")\n",
    "print('Metrics:')\n",
    "print(multi_sequence_result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vLd6IaPSsQy-"
   },
   "outputs": [],
   "source": [
    "#@title Metric computation for multiple sequences (multiprocessing)\n",
    "\n",
    "eval_filenames = [os.path.join(EVAL_DIR, eval_run) for eval_run in EVAL_RUNS]\n",
    "\n",
    "# We provide a simple resolution to parallize the metric computation by each\n",
    "# sequence through multiprocess. However, the performance varies depending on\n",
    "# the machine or clusters. It is recommended to parallelize the metric\n",
    "# computation as much as possible, depending on your computational resources.\n",
    "with mp.Pool(processes=2) as pool:\n",
    "  metric_objects = pool.imap(_compute_metric_for_dataset, eval_filenames)\n",
    "  pool.close()\n",
    "  pool.join()\n",
    "\n",
    "multi_sequence_result = camera_segmentation_metrics.aggregate_metrics(\n",
    "    metric_objects)\n",
    "print('Metrics:')\n",
    "print(multi_sequence_result)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rpz_j4B9xRXv"
   },
   "source": [
    "# Generate Submission"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uwlvldaa_Oo3"
   },
   "outputs": [],
   "source": [
    "#@title Submission helper functions\n",
    "\n",
    "def _make_submission_proto(\n",
    ") -> submission_pb2.CameraSegmentationSubmission:\n",
    "  \"\"\"Makes a submission proto to store predictions for one shard.\"\"\"\n",
    "  submission = submission_pb2.CameraSegmentationSubmission()\n",
    "  submission.account_name = 'me@gmail.com'\n",
    "  submission.unique_method_name = 'My method'\n",
    "  submission.authors.extend(['Author 1', 'Author 2', 'Author 3'])\n",
    "  submission.affiliation = 'My institute'\n",
    "  submission.description = 'Description of my method'\n",
    "  submission.method_link = 'http://example.com/'\n",
    "  submission.frame_dt = 1\n",
    "  submission.runtime_ms = 1000\n",
    "  return submission\n",
    "\n",
    "\n",
    "def _load_dataset_for_one_test_shard(\n",
    "    filename: str, \n",
    "    filter_by_timestamps: bool = False,\n",
    "    valid_timestamps: Optional[Sequence[str]] = None,\n",
    ") -> Tuple[List[open_dataset.Frame], str]:\n",
    "  \"\"\"Loads (subsampled) dataset frames and sequence id for evaluation.\"\"\"\n",
    "  dataset = tf.data.TFRecordDataset(filename, compression_type='')\n",
    "  test_frames = []\n",
    "  context_name = ''\n",
    "\n",
    "  for index, data in enumerate(dataset):\n",
    "    frame = open_dataset.Frame()\n",
    "    frame.ParseFromString(bytearray(data.numpy()))\n",
    "    \n",
    "    if not context_name:\n",
    "      context_name = frame.context.name\n",
    "    # Skip frames if `filter_by_timestamps` is set to True and \n",
    "    # corresponding timestamp is in the `valid_timestamps` list. Otherwise, keep all the frames\n",
    "    # from a test sequence in the submission.\n",
    "    if filter_by_timestamps and not str(frame.timestamp_micros) in valid_timestamps:\n",
    "      continue\n",
    "    test_frames.append(frame)\n",
    "      \n",
    "  return test_frames, context_name\n",
    "\n",
    "\n",
    "def _generate_predictions_for_one_test_shard(\n",
    "    submission: submission_pb2.CameraSegmentationSubmission,\n",
    "    filename: str,\n",
    "    valid_timestamps: Optional[Sequence[str]] = None,\n",
    ") -> None:\n",
    "  \"\"\"Iterate over all test frames in one sequence and generate predictions.\"\"\"\n",
    "  test_frames, context_name = _load_dataset_for_one_test_shard(\n",
    "      filename,\n",
    "      filter_by_timestamps=True if valid_timestamps else False,\n",
    "      valid_timestamps=valid_timestamps)\n",
    "  image_protos_ordered = []\n",
    "  frame_timestamps_ordered = []\n",
    "  camera_names_ordered = []\n",
    "  print(f'Loading test sequence with context {context_name}...')\n",
    "  for frame in test_frames:\n",
    "    image_proto_dict = {image.name: image.image for image in frame.images}\n",
    "    image_protos_ordered.append(\n",
    "        [image_proto_dict[name] for name in camera_left_to_right_order])\n",
    "    frame_timestamps_dict = {image.name: frame.timestamp_micros for image in frame.images}\n",
    "    frame_timestamps_ordered.append(\n",
    "        [frame_timestamps_dict[name] for name in camera_left_to_right_order])\n",
    "    camera_names_ordered.append(\n",
    "        [name for name in camera_left_to_right_order])\n",
    "\n",
    "  print(f'Processing test sequence with context {context_name}...')\n",
    "  image_protos_flat = sum(image_protos_ordered, [])\n",
    "  frame_timestamps_flat = sum(frame_timestamps_ordered, [])\n",
    "  camera_names_flat = sum(camera_names_ordered, [])\n",
    "  # We provide a dummy inference function in the demo. Please replace this with \n",
    "  # your own method. It is recommended to generate your own panoptic labels first\n",
    "  # and implement a function to load the generated panoptic labels from the disk.\n",
    "  panoptic_preds = _run_dummy_inference_from_protos(image_protos_flat)\n",
    "  # The `panoptic_label_divisor` must be greater than the largest number of \n",
    "  # instances in a single frame.\n",
    "  panoptic_label_divisor = 1000\n",
    "  for panoptic_pred, frame_timestamp, camera_name in zip(\n",
    "      panoptic_preds, \n",
    "      frame_timestamps_flat, \n",
    "      camera_names_flat):\n",
    "    # In the tutorial, we simply use the context name as the unique identifier \n",
    "    # for the test sequence. Note this `sequence_id` will not be used for eval,\n",
    "    # since we only compare the `context_name` in the CameraSegmentationFrame.\n",
    "    label_sequence_id = context_name\n",
    "    seg_proto = camera_segmentation_utils.save_panoptic_label_to_proto(\n",
    "        panoptic_pred,\n",
    "        panoptic_label_divisor,\n",
    "        label_sequence_id)\n",
    "    seg_frame = metrics_pb2.CameraSegmentationFrame(\n",
    "        camera_segmentation_label=seg_proto,\n",
    "        context_name=context_name,\n",
    "        frame_timestamp_micros=frame_timestamp,\n",
    "        camera_name=camera_name\n",
    "    )\n",
    "    submission.predicted_segmentation_labels.frames.extend([seg_frame])\n",
    "\n",
    "\n",
    "def _save_submission_to_file(\n",
    "    submission: submission_pb2.CameraSegmentationSubmission,\n",
    "    filename: str,\n",
    "    save_folder: str = SAVE_FOLDER,\n",
    ") -> None:\n",
    "  \"\"\"Save predictions for one sequence as a binary protobuf.\"\"\"\n",
    "  os.makedirs(save_folder, exist_ok=True)\n",
    "  basename = os.path.basename(filename)\n",
    "  if '.tfrecord' not in basename:\n",
    "    raise ValueError('Cannot determine file path for saving submission.')\n",
    "  submission_basename = basename.replace('_with_camera_labels.tfrecord',\n",
    "                                         '_camera_segmentation_submission.binproto')\n",
    "  submission_file_path = os.path.join(save_folder, submission_basename)\n",
    "  print(f'Saving predictions to {submission_file_path}...\\n')\n",
    "  f = open(submission_file_path, 'wb')\n",
    "  f.write(submission.SerializeToString())\n",
    "  f.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VCowlAtjZHnR"
   },
   "outputs": [],
   "source": [
    "#@title Generate a submission file for a single sequence\n",
    "\n",
    "if os.path.isdir(SAVE_FOLDER):\n",
    "  os.removedirs(SAVE_FOLDER)\n",
    "\n",
    "submission = _make_submission_proto()\n",
    "print('Submission proto size: ', len(submission.SerializeToString()))\n",
    "# Here, we do not set the valid_timestamps assuming all the frames in a test\n",
    "# sequence will be included.\n",
    "# However, in the final submission please consider only include the requested\n",
    "# frames to reduce the submission package size.\n",
    "_generate_predictions_for_one_test_shard(submission, FILE_NAME)\n",
    "print('Submission proto size: ', len(submission.SerializeToString()))\n",
    "_save_submission_to_file(submission, FILE_NAME, SAVE_FOLDER)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_7p3Bx_md_LZ"
   },
   "outputs": [],
   "source": [
    "#@title Generate submission files for multiple sequences in the test set\n",
    "\n",
    "if os.path.isdir(SAVE_FOLDER):\n",
    "  os.removedirs(SAVE_FOLDER)\n",
    "\n",
    "context_name_timestamp_tuples = [x.rstrip().split(',') for x in (\n",
    "    tf.io.gfile.GFile(TEST_SET_SOURCE, 'r').readlines())]\n",
    "\n",
    "context_names_dict = {}\n",
    "for context_name, timestamp_micros in context_name_timestamp_tuples:\n",
    "  if context_name not in context_names_dict:\n",
    "    context_names_dict.update({context_name: [timestamp_micros]})\n",
    "  else:\n",
    "    context_names_dict[context_name].append(timestamp_micros)\n",
    "\n",
    "# We save each test sequence in an indepedent submission file.\n",
    "for context_name in context_names_dict.keys():\n",
    "  test_filename = os.path.join(TEST_DIR, f'segment-{context_name}_with_camera_labels.tfrecord')\n",
    "  if not tf.io.gfile.exists(test_filename):\n",
    "    raise ValueError(f'Missing .tfrecord {context_name} under {TEST_DIR}.')\n",
    "  submission = _make_submission_proto()\n",
    "  print('Submission proto size: ', len(submission.SerializeToString()))\n",
    "  # We only include frames with timestamps requested from the .txt file in \n",
    "  # the submission.\n",
    "  _generate_predictions_for_one_test_shard(\n",
    "      submission, \n",
    "      test_filename,\n",
    "      valid_timestamps=context_names_dict[context_name])\n",
    "  print('Submission proto size: ', len(submission.SerializeToString()))\n",
    "  _save_submission_to_file(submission, test_filename, SAVE_FOLDER)"
   ]
  },
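  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The grouping of requested timestamps by context name in the cell above can also be written with a `defaultdict`. The sketch below is self-contained and uses made-up lines in the `<context_name>, <timestamp_micros>` format; `group_timestamps_by_context` and the sample values are hypothetical.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def group_timestamps_by_context(lines):\n",
    "  # Maps each context name to the list of timestamps requested for it,\n",
    "  # preserving the order in which the timestamps appear.\n",
    "  grouped = defaultdict(list)\n",
    "  for line in lines:\n",
    "    if not line.strip():\n",
    "      continue  # Ignore blank lines.\n",
    "    context_name, timestamp_micros = (\n",
    "        field.strip() for field in line.split(','))\n",
    "    grouped[context_name].append(timestamp_micros)\n",
    "  return dict(grouped)\n",
    "\n",
    "\n",
    "# Made-up example lines; real files list actual context names and timestamps.\n",
    "example_lines = ['ctx_a,100', 'ctx_b, 200', '', 'ctx_a,300']\n",
    "grouped_timestamps = group_timestamps_by_context(example_lines)\n",
    "```\n",
    "\n",
    "Timestamps are kept as strings here because the submission helper compares them against `str(frame.timestamp_micros)`."
   ]
  },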
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "G1JGZnNnOY9J"
   },
   "source": [
    "## Packaging submission files "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XgGSt3p0OmO2"
   },
   "outputs": [],
   "source": [
    "!tar czvf /tmp/camera_segmentation_challenge/submit_testing.tar.gz -C $SAVE_FOLDER ."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Waymo Open Dataset 2D Panoramic Video Panoptic Segmentation Tutorial.ipynb",
   "private_outputs": true,
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
